{
 "cells": [
  {
   "cell_type": "code",
   "id": "9e75ddcf-824d-4c86-aa3e-662f1737c956",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:00.144274Z",
     "iopub.status.busy": "2024-10-01T14:01:00.143525Z",
     "iopub.status.idle": "2024-10-01T14:01:00.171806Z",
     "shell.execute_reply": "2024-10-01T14:01:00.169563Z",
     "shell.execute_reply.started": "2024-10-01T14:01:00.144198Z"
    }
   },
   "source": [
    "import os\n",
    "\n",
    "# Credentials come from the environment instead of `%env KEY=value`: that magic echoes\n",
    "# \"env: KEY=value\" into the cell output, so a real key would be saved in the notebook.\n",
    "# Export LLM_API_KEY in the shell that starts Jupyter.\n",
    "# NOTE(review): the HyDE cells below use a local Ollama model, so the key and base URL\n",
    "# only matter if the LLM is switched to an OpenAI-compatible endpoint such as DashScope.\n",
    "os.environ.setdefault('LLM_BASE_URL', 'https://dashscope.aliyuncs.com/compatible-mode/v1')\n",
    "if not os.environ.get('LLM_API_KEY'):\n",
    "    print('LLM_API_KEY is not set; export it before using a remote LLM.')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "af375836-b870-458b-87d1-4e00565977eb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:00.954583Z",
     "iopub.status.busy": "2024-10-01T14:01:00.954389Z",
     "iopub.status.idle": "2024-10-01T14:01:00.956662Z",
     "shell.execute_reply": "2024-10-01T14:01:00.956352Z",
     "shell.execute_reply.started": "2024-10-01T14:01:00.954569Z"
    }
   },
   "outputs": [],
   "source": [
    "%%capture --no-stderr\n",
    "# `shutil` is in the Python standard library, not on PyPI: listing it makes pip fail to\n",
    "# resolve the requirements and install nothing. pandas is imported below, so list it too.\n",
    "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
    "%pip install -U langchain langchain_community langchain_openai pypdf sentence_transformers chromadb openpyxl FlagEmbedding pandas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1e2c72b8-ee12-4130-af88-699998aa230c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:01.441684Z",
     "iopub.status.busy": "2024-10-01T14:01:01.441516Z",
     "iopub.status.idle": "2024-10-01T14:01:04.675892Z",
     "shell.execute_reply": "2024-10-01T14:01:04.675396Z",
     "shell.execute_reply.started": "2024-10-01T14:01:01.441669Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.10/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
      "  from tqdm.autonotebook import tqdm, trange\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "langchain                     0.2.10\n",
      "langchain_core                0.2.28\n",
      "langchain_community           0.2.9\n",
      "pypdf                         4.3.1\n",
      "sentence_transformers         3.0.1\n",
      "chromadb                      0.5.4\n"
     ]
    }
   ],
   "source": [
    "import langchain\n",
    "import langchain_community\n",
    "import pypdf\n",
    "import sentence_transformers\n",
    "import chromadb\n",
    "import langchain_core\n",
    "\n",
    "# Record library versions alongside the results so the hit rates can be reproduced.\n",
    "for lib in [langchain, langchain_core, langchain_community, pypdf, sentence_transformers, chromadb]:\n",
    "    print('{:<30}{}'.format(lib.__name__, lib.__version__))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "facc1812-d307-45b6-8390-f90536229eb9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:04.676930Z",
     "iopub.status.busy": "2024-10-01T14:01:04.676674Z",
     "iopub.status.idle": "2024-10-01T14:01:04.679421Z",
     "shell.execute_reply": "2024-10-01T14:01:04.678973Z",
     "shell.execute_reply.started": "2024-10-01T14:01:04.676916Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "841d2b02-ad06-40d2-b11f-c7adccec6ca2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:04.679961Z",
     "iopub.status.busy": "2024-10-01T14:01:04.679842Z",
     "iopub.status.idle": "2024-10-01T14:01:04.697603Z",
     "shell.execute_reply": "2024-10-01T14:01:04.695217Z",
     "shell.execute_reply.started": "2024-10-01T14:01:04.679949Z"
    }
   },
   "outputs": [],
   "source": [
    "# Experiment tag: this run's vector stores are written under ../experiments/<expr_version>.\n",
    "expr_version = 'retrieval_v7_hyde'\n",
    "\n",
    "# Shared pre-processing outputs (cached chunks, QA set) are read from here.\n",
    "preprocess_output_dir = os.path.join(os.pardir, 'outputs', 'v1_20240713')\n",
    "expr_dir = os.path.join(os.pardir, 'experiments', expr_version)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf7e81e3-4c82-4842-aef5-7592caaf1d39",
   "metadata": {},
   "source": [
    "# 读取文档"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "da15f02e-3131-43fb-81c5-f4da615c449b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:04.702833Z",
     "iopub.status.busy": "2024-10-01T14:01:04.702125Z",
     "iopub.status.idle": "2024-10-01T14:01:06.263850Z",
     "shell.execute_reply": "2024-10-01T14:01:06.263323Z",
     "shell.execute_reply.started": "2024-10-01T14:01:04.702764Z"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_community.document_loaders import PyPDFLoader\n",
    "\n",
    "# Retrieval corpus: the source PDF report.\n",
    "loader = PyPDFLoader(os.path.join(os.pardir, 'data', '2024全球经济金融展望报告.pdf'))\n",
    "documents = loader.load()\n",
    "\n",
    "# Evaluation set: questions plus the `uuid` of the chunk each one should retrieve.\n",
    "qa_df = pd.read_excel(os.path.join(preprocess_output_dir, 'question_answer.xlsx'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "841ec659-4ad7-4e1f-b1ea-3477bf97fde3",
   "metadata": {},
   "source": [
    "# 文档切分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "74fe856a-7c19-4c3c-bb30-7abfa6298f74",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:06.264635Z",
     "iopub.status.busy": "2024-10-01T14:01:06.264408Z",
     "iopub.status.idle": "2024-10-01T14:01:06.270936Z",
     "shell.execute_reply": "2024-10-01T14:01:06.270608Z",
     "shell.execute_reply.started": "2024-10-01T14:01:06.264621Z"
    }
   },
   "outputs": [],
   "source": [
    "from uuid import uuid4\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "\n",
    "def split_docs(documents, filepath, chunk_size=400, chunk_overlap=40, seperators=['\\n\\n\\n', '\\n\\n'], force_split=False):\n",
    "    \"\"\"Split `documents` into chunks tagged with a random uuid, caching the result as a pickle.\n",
    "\n",
    "    Args:\n",
    "        documents: LangChain documents to split.\n",
    "        filepath: Cache file. If it exists and `force_split` is False, its contents are\n",
    "            returned unchanged and no splitting happens.\n",
    "        chunk_size, chunk_overlap, seperators: Passed to RecursiveCharacterTextSplitter\n",
    "            (the misspelt `seperators` keyword is kept for backward compatibility).\n",
    "        force_split: Re-split and overwrite the cache even if it exists.\n",
    "\n",
    "    Returns:\n",
    "        List of chunks, each with `metadata['uuid']` set.\n",
    "\n",
    "    The cache is keyed only by `filepath`: changing the split parameters does not invalidate\n",
    "    it. Re-splitting also draws fresh uuids, which will presumably no longer match the `uuid`\n",
    "    column of question_answer.xlsx if that file was built from the cached chunks.\n",
    "    \"\"\"\n",
    "    if os.path.exists(filepath) and not force_split:\n",
    "        print('found cache, restoring...')\n",
    "        # pickle.load can execute arbitrary code: only restore caches this project wrote.\n",
    "        with open(filepath, 'rb') as f:\n",
    "            return pickle.load(f)\n",
    "\n",
    "    splitter = RecursiveCharacterTextSplitter(\n",
    "        chunk_size=chunk_size,\n",
    "        chunk_overlap=chunk_overlap,\n",
    "        separators=seperators\n",
    "    )\n",
    "    chunks = splitter.split_documents(documents)\n",
    "    for chunk in chunks:\n",
    "        chunk.metadata['uuid'] = str(uuid4())\n",
    "\n",
    "    # Create the cache directory if needed; the `with` block closes (and flushes) the file.\n",
    "    cache_dir = os.path.dirname(filepath)\n",
    "    if cache_dir:\n",
    "        os.makedirs(cache_dir, exist_ok=True)\n",
    "    with open(filepath, 'wb') as f:\n",
    "        pickle.dump(chunks, f)\n",
    "\n",
    "    return chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "aa25540d-0504-4ae7-9804-9e3862b132d5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:06.271519Z",
     "iopub.status.busy": "2024-10-01T14:01:06.271396Z",
     "iopub.status.idle": "2024-10-01T14:01:06.282320Z",
     "shell.execute_reply": "2024-10-01T14:01:06.281909Z",
     "shell.execute_reply.started": "2024-10-01T14:01:06.271507Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "found cache, restoring...\n"
     ]
    }
   ],
   "source": [
    "# split_docs restores split_docs.pkl when it exists, so chunk_size/chunk_overlap here only\n",
    "# take effect on a fresh split (force_split=True or a missing cache).\n",
    "splitted_docs = split_docs(\n",
    "    documents,\n",
    "    os.path.join(preprocess_output_dir, 'split_docs.pkl'),\n",
    "    chunk_size=500,\n",
    "    chunk_overlap=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "220dbc3a-fceb-4e49-a3f1-01e16660b2a6",
   "metadata": {},
   "source": [
    "# 检索"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8598a11c-25d8-4af1-a98b-06a8c394e261",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:06.283284Z",
     "iopub.status.busy": "2024-10-01T14:01:06.283162Z",
     "iopub.status.idle": "2024-10-01T14:01:06.297483Z",
     "shell.execute_reply": "2024-10-01T14:01:06.296976Z",
     "shell.execute_reply.started": "2024-10-01T14:01:06.283272Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: cuda\n"
     ]
    }
   ],
   "source": [
    "from langchain.embeddings import HuggingFaceBgeEmbeddings\n",
    "import torch\n",
    "\n",
    "# Use the GPU when torch can see one.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'device: {device}')\n",
    "\n",
    "def get_embeddings(model_path):\n",
    "    \"\"\"Return a BGE embedding model loaded from `model_path` on the global `device`.\n",
    "\n",
    "    Embeddings are normalized (`normalize_embeddings=True`) and queries use the Chinese\n",
    "    BGE retrieval instruction as `query_instruction`.\n",
    "    \"\"\"\n",
    "    return HuggingFaceBgeEmbeddings(\n",
    "        model_name=model_path,\n",
    "        model_kwargs=dict(device=device),\n",
    "        encode_kwargs=dict(normalize_embeddings=True),\n",
    "        query_instruction='为这个句子生成表示以用于检索相关文章：',\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "663ef1a4-5866-4f6b-8d9d-4724f62142cb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:07.508728Z",
     "iopub.status.busy": "2024-10-01T14:01:07.508429Z",
     "iopub.status.idle": "2024-10-01T14:01:07.514319Z",
     "shell.execute_reply": "2024-10-01T14:01:07.513300Z",
     "shell.execute_reply.started": "2024-10-01T14:01:07.508701Z"
    }
   },
   "outputs": [],
   "source": [
    "import shutil\n",
    "from langchain_community.vectorstores import Chroma\n",
    "\n",
    "# Can be replaced with a local path to the model.\n",
    "model_path = 'BAAI/bge-large-zh-v1.5'\n",
    "\n",
    "def get_vector_db(embeddings, docs, db_name):\n",
    "    \"\"\"Build a fresh persistent Chroma store for `docs` under <expr_dir>/chroma/<db_name>.\n",
    "\n",
    "    Args:\n",
    "        embeddings: LangChain embeddings object handed to Chroma as its embedding function.\n",
    "        docs: Chunks to index.\n",
    "        db_name: Sub-directory name of this store.\n",
    "\n",
    "    Returns:\n",
    "        The populated Chroma vector store.\n",
    "    \"\"\"\n",
    "    persist_directory = os.path.join(expr_dir, 'chroma', db_name)\n",
    "    # Delete any previous store at this path so vectors from an older run are not mixed in.\n",
    "    shutil.rmtree(persist_directory, ignore_errors=True)\n",
    "\n",
    "    # Index the `docs` argument itself, not a notebook-level global.\n",
    "    return Chroma.from_documents(\n",
    "        docs,\n",
    "        embedding=embeddings,\n",
    "        persist_directory=persist_directory\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b03e3382-39e9-4932-a265-69b811041629",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:09.677519Z",
     "iopub.status.busy": "2024-10-01T14:01:09.677031Z",
     "iopub.status.idle": "2024-10-01T14:01:09.687722Z",
     "shell.execute_reply": "2024-10-01T14:01:09.685751Z",
     "shell.execute_reply.started": "2024-10-01T14:01:09.677475Z"
    }
   },
   "outputs": [],
   "source": [
    "# Evaluate on the test split, 'detailed' questions only.\n",
    "test_df = qa_df[qa_df['dataset'].eq('test') & qa_df['qa_type'].eq('detailed')]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b2e4b81-bd77-45e7-bfb0-ae321e51fe90",
   "metadata": {},
   "source": [
    "## 不使用HyDE\n",
    "\n",
    "基线：直接用原始问题的向量在向量库中检索，统计不同 top_k 下的命中率。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "11844472-9206-4411-bd5f-e0c556f0ac73",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:12.589942Z",
     "iopub.status.busy": "2024-10-01T14:01:12.589777Z",
     "iopub.status.idle": "2024-10-01T14:01:27.132752Z",
     "shell.execute_reply": "2024-10-01T14:01:27.130292Z",
     "shell.execute_reply.started": "2024-10-01T14:01:12.589929Z"
    }
   },
   "outputs": [],
   "source": [
    "# Baseline store: chunks and questions are both embedded with the plain BGE model.\n",
    "embeddings = get_embeddings(model_path)\n",
    "vector_db = get_vector_db(embeddings, splitted_docs, db_name='vanilla_retrieval')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "32c3ad14-b217-44aa-bdb9-909b9d559668",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:39.238422Z",
     "iopub.status.busy": "2024-10-01T14:01:39.238215Z",
     "iopub.status.idle": "2024-10-01T14:01:39.242918Z",
     "shell.execute_reply": "2024-10-01T14:01:39.242501Z",
     "shell.execute_reply.started": "2024-10-01T14:01:39.238404Z"
    }
   },
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "def get_hit_stat_df(vector_db, top_k_arr=tuple(range(1, 9)), eval_df=None):\n",
    "    \"\"\"Compute per-question hit@k of `vector_db` for each k in `top_k_arr`.\n",
    "\n",
    "    A question is a hit at k when its `uuid` (the chunk that answers it) is among the\n",
    "    `metadata['uuid']` values of `vector_db.similarity_search(question, k=k)`.\n",
    "\n",
    "    Args:\n",
    "        vector_db: LangChain vector store whose documents carry `metadata['uuid']`.\n",
    "        top_k_arr: Values of k to evaluate (any iterable of ints).\n",
    "        eval_df: Questions to evaluate, with `question` and `uuid` columns. Defaults to\n",
    "            the notebook-level `test_df`.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per (k, question) and columns question, top_k, hit (0/1),\n",
    "        retrieved_chunks.\n",
    "    \"\"\"\n",
    "    if eval_df is None:\n",
    "        eval_df = test_df\n",
    "    top_k_arr = list(top_k_arr)\n",
    "\n",
    "    hit_stat_data = []\n",
    "    # The context manager closes the progress bar even if a search raises.\n",
    "    with tqdm(total=len(top_k_arr) * len(eval_df)) as pbar:\n",
    "        for k in top_k_arr:\n",
    "            pbar.set_description(f'k={k}')\n",
    "            for question, true_uuid in zip(eval_df['question'], eval_df['uuid']):\n",
    "                chunks = vector_db.similarity_search(question, k=k)\n",
    "                retrieved_uuids = {doc.metadata['uuid'] for doc in chunks}\n",
    "\n",
    "                hit_stat_data.append({\n",
    "                    'question': question,\n",
    "                    'top_k': k,\n",
    "                    'hit': int(true_uuid in retrieved_uuids),\n",
    "                    'retrieved_chunks': len(chunks)\n",
    "                })\n",
    "                pbar.update(1)\n",
    "    return pd.DataFrame(hit_stat_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "93dd701a-bd89-44db-a954-c78dd7001e38",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:01:40.322993Z",
     "iopub.status.busy": "2024-10-01T14:01:40.322207Z",
     "iopub.status.idle": "2024-10-01T14:01:58.071781Z",
     "shell.execute_reply": "2024-10-01T14:01:58.071346Z",
     "shell.execute_reply.started": "2024-10-01T14:01:40.322894Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e9b83bbd9b38425989c7e4e0113fff09",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/744 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Baseline hit rates, tagged HyDE='w/o' for the comparison below.\n",
    "orig_query_hit_stat_df = get_hit_stat_df(vector_db).assign(HyDE='w/o')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd9f5f38-28b2-4b77-a101-8694e505678f",
   "metadata": {},
   "source": [
    "## 使用HyDE\n",
    "\n",
    "HyDE（Hypothetical Document Embeddings）：先让LLM针对问题生成一段假设性回答，再用该回答的向量代替问题向量进行检索。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f25c0878-151a-4d06-901d-3bf4add8a86a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:02:12.199045Z",
     "iopub.status.busy": "2024-10-01T14:02:12.198220Z",
     "iopub.status.idle": "2024-10-01T14:02:26.373926Z",
     "shell.execute_reply": "2024-10-01T14:02:26.372639Z",
     "shell.execute_reply.started": "2024-10-01T14:02:12.198968Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from langchain_community.chat_models import ChatOllama\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.chains import HypotheticalDocumentEmbedder\n",
    "import re\n",
    "\n",
    "# Generator for the hypothetical answers: a local Ollama server on port 11434 serving qwen2:7b-instruct.\n",
    "llm = ChatOllama(model='qwen2:7b-instruct', base_url='http://localhost:11434')\n",
    "\n",
    "# HyDE prompt: have the LLM answer the question as a finance analyst; that answer is what\n",
    "# HypotheticalDocumentEmbedder embeds in place of the raw question.\n",
    "prompt = PromptTemplate(\n",
    "    template=\"\"\"你是一位精通财经的专业分析师，请基于用户提问，回答问题。\n",
    "问题：{question}\n",
    "请回答：\n",
    "\"\"\",\n",
    "    input_variables=['question'],\n",
    ")\n",
    "\n",
    "# NOTE(review): HyDE is expected to change only query embeddings (documents still go through\n",
    "# `embeddings`), so this store presumably holds the same vectors as the baseline one.\n",
    "hyde_embedding = HypotheticalDocumentEmbedder.from_llm(\n",
    "    llm,\n",
    "    base_embeddings=embeddings,\n",
    "    custom_prompt=prompt,\n",
    ")\n",
    "hyde_vector_db = get_vector_db(hyde_embedding, splitted_docs, 'hyde')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ce49bf41-447b-470c-a243-99b583bdd773",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:02:28.361191Z",
     "iopub.status.busy": "2024-10-01T14:02:28.360771Z",
     "iopub.status.idle": "2024-10-01T14:02:31.266461Z",
     "shell.execute_reply": "2024-10-01T14:02:31.265977Z",
     "shell.execute_reply.started": "2024-10-01T14:02:28.361152Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='我是阿里云开发的一款超大规模语言模型，我叫通义千问。作为一个AI助手，我的目标是帮助用户获得准确、有用的信息，解决他们的问题和困惑。我可以回答各种领域的问题，提供代码实现、解释概念、提供建议等。请随时告诉我您需要帮助的内容，我会尽力提供支持。', response_metadata={'model': 'qwen2:7b-instruct', 'created_at': '2024-10-01T14:02:31.26354926Z', 'message': {'role': 'assistant', 'content': ''}, 'done_reason': 'stop', 'done': True, 'total_duration': 2895563597, 'load_duration': 1588807712, 'prompt_eval_count': 10, 'prompt_eval_duration': 42270000, 'eval_count': 71, 'eval_duration': 1219858000}, id='run-9ab48c32-3bd3-46b0-8887-98c3108a192e-0')"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: check that the local Ollama model responds before the long HyDE run.\n",
    "llm.invoke(input='你是谁')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d7d1bf20-7169-45d0-958e-1a4bece5a8a4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:02:38.818523Z",
     "iopub.status.busy": "2024-10-01T14:02:38.817739Z",
     "iopub.status.idle": "2024-10-01T14:02:38.834095Z",
     "shell.execute_reply": "2024-10-01T14:02:38.831771Z",
     "shell.execute_reply.started": "2024-10-01T14:02:38.818451Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_hyde_hit_stat_df(vector_db, k):\n",
    "    \"\"\"Per-question hit@k of `vector_db` over `test_df` for a single k.\n",
    "\n",
    "    Delegates to `get_hit_stat_df` restricted to `[k]`, so both produce identical rows\n",
    "    (question, top_k, hit, retrieved_chunks) instead of maintaining two copies of the loop.\n",
    "\n",
    "    NOTE(review): on a HyDE store each text query presumably triggers a fresh LLM generation,\n",
    "    so calling this once per k scores different hypothetical answers at different k, which can\n",
    "    make hit@k non-monotonic in k. Reusing one query vector per question avoids that.\n",
    "    \"\"\"\n",
    "    return get_hit_stat_df(vector_db, top_k_arr=[k])"
   ]
  },
  {
   "cell_type": "code",
   "id": "a9c8ee12-b922-42bf-b80a-283cf65ee1b2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-29T07:05:42.339760Z",
     "iopub.status.busy": "2024-09-29T07:05:42.339573Z",
     "iopub.status.idle": "2024-09-29T08:12:44.883205Z",
     "shell.execute_reply": "2024-09-29T08:12:44.882698Z",
     "shell.execute_reply.started": "2024-09-29T07:05:42.339744Z"
    }
   },
   "source": [
    "from concurrent.futures import ThreadPoolExecutor\n",
    "\n",
    "top_k_arr=list(range(1, 9))\n",
    "\n",
    "# HyDE calls the LLM inside `embed_query`. Searching by question text once per k would generate\n",
    "# a new hypothetical answer for every (question, k) pair: len(top_k_arr) times the LLM calls, and\n",
    "# hit@k scored on different answers for different k (noisy, possibly non-monotonic in k).\n",
    "# Instead, build one HyDE query vector per question and reuse it for every k.\n",
    "questions = test_df['question'].tolist()\n",
    "with ThreadPoolExecutor(max_workers=2) as executor:\n",
    "    hyde_query_vectors = list(tqdm(\n",
    "        executor.map(hyde_embedding.embed_query, questions),\n",
    "        total=len(questions),\n",
    "        desc='HyDE queries',\n",
    "    ))\n",
    "\n",
    "hyde_hit_stat_data = []\n",
    "for k in top_k_arr:\n",
    "    for question, true_uuid, query_vector in zip(questions, test_df['uuid'], hyde_query_vectors):\n",
    "        chunks = hyde_vector_db.similarity_search_by_vector(query_vector, k=k)\n",
    "        retrieved_uuids = [doc.metadata['uuid'] for doc in chunks]\n",
    "        hyde_hit_stat_data.append({\n",
    "            'question': question,\n",
    "            'top_k': k,\n",
    "            'hit': int(true_uuid in retrieved_uuids),\n",
    "            'retrieved_chunks': len(chunks)\n",
    "        })\n",
    "hyde_hit_stat_df = pd.DataFrame(hyde_hit_stat_data)\n",
    "hyde_hit_stat_df['HyDE'] = 'w/'"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "899d1393-3ba3-48f7-b10e-c5c3de8a8927",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-29T08:12:44.883839Z",
     "iopub.status.busy": "2024-09-29T08:12:44.883680Z",
     "iopub.status.idle": "2024-09-29T08:12:44.886735Z",
     "shell.execute_reply": "2024-09-29T08:12:44.886282Z",
     "shell.execute_reply.started": "2024-09-29T08:12:44.883826Z"
    }
   },
   "outputs": [],
   "source": [
    "# Stack baseline and HyDE rows; the `HyDE` column tells them apart.\n",
    "hit_stat_df = pd.concat([orig_query_hit_stat_df, hyde_hit_stat_df], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d4890789-a44c-41de-b17f-0ff505788494",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-29T08:12:44.887294Z",
     "iopub.status.busy": "2024-09-29T08:12:44.887172Z",
     "iopub.status.idle": "2024-09-29T08:12:44.904902Z",
     "shell.execute_reply": "2024-09-29T08:12:44.904470Z",
     "shell.execute_reply.started": "2024-09-29T08:12:44.887282Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>HyDE</th>\n",
       "      <th>top_k</th>\n",
       "      <th>hit_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>w/</td>\n",
       "      <td>1</td>\n",
       "      <td>0.440860</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>w/</td>\n",
       "      <td>2</td>\n",
       "      <td>0.559140</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>w/</td>\n",
       "      <td>3</td>\n",
       "      <td>0.666667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>w/</td>\n",
       "      <td>4</td>\n",
       "      <td>0.698925</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>w/</td>\n",
       "      <td>5</td>\n",
       "      <td>0.720430</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>w/</td>\n",
       "      <td>6</td>\n",
       "      <td>0.774194</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>w/</td>\n",
       "      <td>7</td>\n",
       "      <td>0.763441</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>w/</td>\n",
       "      <td>8</td>\n",
       "      <td>0.827957</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>w/o</td>\n",
       "      <td>1</td>\n",
       "      <td>0.462366</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>w/o</td>\n",
       "      <td>2</td>\n",
       "      <td>0.591398</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>w/o</td>\n",
       "      <td>3</td>\n",
       "      <td>0.688172</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>w/o</td>\n",
       "      <td>4</td>\n",
       "      <td>0.774194</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>w/o</td>\n",
       "      <td>5</td>\n",
       "      <td>0.806452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>w/o</td>\n",
       "      <td>6</td>\n",
       "      <td>0.817204</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>w/o</td>\n",
       "      <td>7</td>\n",
       "      <td>0.838710</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>w/o</td>\n",
       "      <td>8</td>\n",
       "      <td>0.849462</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   HyDE  top_k  hit_rate\n",
       "0    w/      1  0.440860\n",
       "1    w/      2  0.559140\n",
       "2    w/      3  0.666667\n",
       "3    w/      4  0.698925\n",
       "4    w/      5  0.720430\n",
       "5    w/      6  0.774194\n",
       "6    w/      7  0.763441\n",
       "7    w/      8  0.827957\n",
       "8   w/o      1  0.462366\n",
       "9   w/o      2  0.591398\n",
       "10  w/o      3  0.688172\n",
       "11  w/o      4  0.774194\n",
       "12  w/o      5  0.806452\n",
       "13  w/o      6  0.817204\n",
       "14  w/o      7  0.838710\n",
       "15  w/o      8  0.849462"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hit rate per (HyDE, top_k): share of test questions whose answer chunk was retrieved.\n",
    "hit_stat_df.groupby(['HyDE', 'top_k']).agg(hit_rate=('hit', 'mean')).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "b0b086d1-6cec-4743-8df6-2ab3b1593689",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-29T08:12:44.905441Z",
     "iopub.status.busy": "2024-09-29T08:12:44.905318Z",
     "iopub.status.idle": "2024-09-29T08:12:45.305164Z",
     "shell.execute_reply": "2024-09-29T08:12:45.304734Z",
     "shell.execute_reply.started": "2024-09-29T08:12:44.905428Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Axes: xlabel='top_k', ylabel='hit'>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGxCAYAAACeKZf2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAwYklEQVR4nO3de1TU9b7/8deAXMQLqAiooZiaiKkUBKGZ7CI55rLcnd2mtgaRes4qKZOVJ+kC2UXsshF3eSRN1C5ubbe7WBlWcxrKHYWhttW8ZKWYOaC/UhQL3MP8/mg5NRs0L8x8hy/Px1rftZzPfD7zfX9wFS8/3893vhan0+kUAACASfgZXQAAAEBrItwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABTIdwAAABT6WB0Ad7W1NSk7777Tl26dJHFYjG6HAAAcAacTqeOHj2q3r17y8/v9Gsz7S7cfPfdd4qOjja6DAAAcA727dunCy644LR92l246dKli6Sffzhdu3Y1uBoAAHAm6urqFB0d7fo9fjrtLtycvBTVtWtXwg0AAG3MmWwpYUMxAAAwFcINAAAwFcINAAAwlXa35+ZMORwOnThxwugyfFJAQID8/f2NLgMAgBYRbv6N0+mU3W7X4cOHjS7Fp4WFhSkqKorvCgIA+BzCzb85GWwiIiIUEhLCL+9/43Q6dfz4cdXW1kqSevXqZXBFAAC4I9z8isPhcAWbHj16GF2Oz+rYsaMkqba2VhEREVyiAgD4FDYU/8rJPTYhISEGV+L7Tv6M2JcEAPA1hJsWcCnqt/EzAgD4KsINAAAwFcKNh916662aOHFis3abzSaLxXLGd2VZLBbX0alTJw0aNEi33nqrqqqqWvzclg673d4KMwIAwLcRbtqQZcuW6cCBA9q2bZsWLlyoY8eOKTk5Wc8//3yzvjt37tSBAwfcjoiICAOqBgDAuwg3Bquvr1fXrl31yiuvuLW//vrr6tSpk44ePepqO/ndMjExMRo7dqxeeeUVTZo0STk5Ofrhhx/cxkdERCgqKsrt8PPjrxsAYH78tjNYp06ddNNNN2nZsmVu7cuWLdMf/vCH33y0+8yZM3X06FG99957niwTAIA2g++58YK33npLnTt3dmtzOByuP0+dOlUjR47UgQMH1KtXL9XW1mrt2rV6//33f/OzY2NjJUl79uxxa7/gggvcXvfr10/btm07xxkAANB2EG684He/+50WLVrk1vbpp59q8uTJkqSkpCQNHTpUK1as0OzZs/Xiiy+qX79+uvLKK3/zs51Op6Tmt2Z/9NFHbqs+AQEB5zsNAICPS5jVfA9ma6t6MtPj5zhfhBsv6NSpkwYOHOjW9u2337q9njp1qhYuXKjZs2dr2bJlys7OPqPvktm+fbskqX///m7t/fv3V1hY2PkVDgBAG8SeGx8xefJk7d27V3/5y1/0xRdfKCsr64zGFRcXq2vXrkpLS/NwhQAAtA2s3PiIbt266YYbbtCsWbM0duzYZntmJOnw4cOy2+1qaGjQrl279Oyzz+r111/X888/32yVpra2Vj/99JNbW48ePbg8BQAwPcKND5kyZYpWrlyp2267rcX3s7OzJUnBwcHq06ePrrjiClVWVurSSy9t1nfw4MHN2ioqKnT55Ze3btEA4IPYe9K+EW48bPny5S22p6amujYDn7R//3716NFD119/fbP+/973VFr6XAAA2hPCjQ84fvy4Dhw4oHnz5um///u/FRgYaHRJAAC0WWwo9gFPPPGEYmNjFRUVpby8PKPLAQCgTWPlxgc89NBDeuihh4wuAwCA31T98DCPn6Nv/pbzGs/KDQAAMBXCDQAAMBXCDQAAMBXCDQAAMBU2FAOAifFldmiPWLkBAACmQrgBAACmYni4WbhwoWJiYhQcHKzk5GRVVlaetn9xcbEGDx6sjh07Kjo6WjNnzmz2gEgAANB+GbrnZvXq1crNzVVJ
SYmSk5NVXFys9PR07dy5UxEREc36r1y5UrNnz1ZpaalGjhypXbt26dZbb5XFYlFRUZFHa/XGdetf8/Q17PLyck2ePFn79u3z6HkAX8HeE6D9MHTlpqioSNOmTVN2drbi4uJUUlKikJAQlZaWttj/448/1qhRo/SnP/1JMTExGjt2rG6++ebfXO1Bc2+88YYmTJhgdBkAALQ6w8JNY2OjqqqqlJaW9ksxfn5KS0tTRUVFi2NGjhypqqoqV5j5+uuvtXbtWl177bWnPE9DQ4Pq6urcDrN56623FBYWJofDIUnavHmzLBaLZs+e7eozdepUTZ482fV6zZo1uu666yT9/DO66667FBERoeDgYF1xxRXasGGDdycBAEArMSzcHDp0SA6HQ5GRkW7tkZGRstvtLY7505/+pIcfflhXXHGFAgICNGDAAKWmpuq+++475XkKCwsVGhrqOqKjo1t1Hr5g9OjROnr0qDZt2iTp50tO4eHhstlsrj7l5eVKTU2VJG3btk21tbW66qqrJEn/8z//o7///e9asWKFNm7cqIEDByo9PV3ff/+9t6cCAMB5M3xD8dmw2WyaO3eu/vd//1cbN27Uq6++qrfffluPPPLIKcfk5eXpyJEjrsOMe0xCQ0MVHx/vCjM2m00zZ87Upk2bdOzYMe3fv1+7d+/WmDFjJP18SSo9PV2BgYGqr6/XokWL9OSTT2rcuHGKi4vTkiVL1LFjRy1dutTAWQEAcG4MCzfh4eHy9/dXTU2NW3tNTY2ioqJaHPPggw/qlltu0dSpUzVs2DD9/ve/19y5c1VYWKimpqYWxwQFBalr165uhxmNGTNGNptNTqdTH330kW644QYNGTJE69evV3l5uXr37q1BgwZJ+jncnLwk9dVXX+nEiRMaNWqU67MCAgKUlJSk7du3GzIXAADOh2HhJjAwUAkJCbJara62pqYmWa1WpaSktDjm+PHj8vNzL9nf31+S5HQ6PVdsG5Camqr169fr888/V0BAgGJjY5Wamiqbzaby8nLXqs2BAwe0adMmjR8/3uCKAQDwDEMvS+Xm5mrJkiVasWKFtm/frttvv1319fXKzs6WJGVmZiovL8/Vf8KECVq0aJFWrVqlb775Ru+9954efPBBTZgwwRVy2quT+27mz5/vCjInw43NZnPtt3nzzTc1cuRIde/eXZI0YMAABQYG6h//+Ifrs06cOKENGzYoLi7O6/MAAOB8Gfo9NxkZGTp48KDy8/Nlt9sVHx+vsrIy1ybj6upqt5WaBx54QBaLRQ888ID279+vnj17asKECXrssceMmoLP6Natm4YPH66XXnpJzzzzjCTpyiuv1B//+EedOHHCFXh+fZeUJHXq1Em33367Zs2ape7du6tv37564okndPz4cU2ZMsWQuQAAcD4Mf3BmTk6OcnJyWnzv13f7SFKHDh1UUFCggoICL1Tmri18OdeYMWO0efNm1ypN9+7dFRcXp5qaGg0ePFj19fWyWq0qLi52Gzdv3jw1NTXplltu0dGjR5WYmKh169apW7du3p8EAADnqU3dLYXTKy4ultPpVGxsrKtt8+bNOnDggCRp3bp16t+/vwYOHOg2Ljg4WH/5y1908OBB/fTTT1q/fr0uu+wyr9YOAEBrIdy0I507d9bjjz9udBkAAHiU4Zel4D1jx441ugQAADyOlRsAAGAqrNwA7QxPx0Zrq354mMfP0Td/i8fPAfNg5QYAAJgK4QYAAJgK4QYAAJgKe24AADgH7DXyXazcAAAAUyHcAAAAU+Gy1BnyxvLjr3l6KbK8vFyTJ0/Wvn37PHoeAAC8jZWbduqNN97QhAkTjC4DAIBWR7gxgbfeekthYWFyOBySfn5YpsVi0ezZs119pk6dqsmTJ7ter1mzRtddd53XawUAwNMINyYwevRoHT16VJs2bZL08yWn8PBw2Ww2V5/y8nKlpqZKkrZt26ba2lpdddVVBlQLAIBnEW5MIDQ0VPHx8a4wY7PZNHPm
TG3atEnHjh3T/v37tXv3bo0ZM0bSz5ek0tPTFRgYaGDVAAB4BuHGJMaMGSObzSan06mPPvpIN9xwg4YMGaL169ervLxcvXv31qBBgyT9HG64JAUAMCvuljKJ1NRUlZaW6vPPP1dAQIBiY2OVmpoqm82mH374wbVqc+DAAW3atEnjx483uGIAADyDlRuTOLnvZv78+a4gczLc2Gw2136bN998UyNHjlT37t0NrBYAAM8h3JhEt27dNHz4cL300kuuIHPllVdq48aN2rVrlyvwcJcUAMDsuCx1htrC8z3GjBmjzZs3u8JN9+7dFRcXp5qaGg0ePFj19fWyWq0qLi42tE4AADyJlRsTKS4ultPpVGxsrKtt8+bNOnDggCRp3bp16t+/vwYOHGhUiQAAeBzhph3p3LmzHn/8caPLAADAo7gs1Y6MHTvW6BIAAPA4Vm4AAICpEG4AAICpcFmqBU6n0+gSfB4/I6C56oeHefwcbeHOTcBorNz8SkBAgCTp+PHjBlfi+07+jE7+zAAA8BWs3PyKv7+/wsLCVFtbK0kKCQmRxWIxuCrf4nQ6dfz4cdXW1iosLEz+/v5GlwQAgBvCzb+JioqSJFfAQcvCwsJcPysAAHyJT4SbhQsX6sknn5TdbteIESP09NNPKykpqcW+qampKi8vb9Z+7bXX6u233z7vWiwWi3r16qWIiAidOHHivD/PjAICAlixAQD4LMPDzerVq5Wbm6uSkhIlJyeruLhY6enp2rlzpyIiIpr1f/XVV9XY2Oh6/f/+3//TiBEjdOONN7ZqXf7+/vwCBwCgDTJ8Q3FRUZGmTZum7OxsxcXFqaSkRCEhISotLW2xf/fu3RUVFeU63nvvPYWEhLR6uAEAAG2ToeGmsbFRVVVVSktLc7X5+fkpLS1NFRUVZ/QZS5cu1U033aROnTp5qkwAANCGGHpZ6tChQ3I4HIqMjHRrj4yM1I4dO35zfGVlpbZu3aqlS5eesk9DQ4MaGhpcr+vq6s69YAAA4PMM33NzPpYuXaphw4adcvOxJBUWFmrOnDlerAoAX2YHwEiGXpYKDw+Xv7+/ampq3Npramp+8zbj+vp6rVq1SlOmTDltv7y8PB05csR17Nu377zrBgAAvsvQcBMYGKiEhARZrVZXW1NTk6xWq1JSUk479m9/+5saGho0efLk0/YLCgpS165d3Q4AAGBehl+Wys3NVVZWlhITE5WUlKTi4mLV19crOztbkpSZmak+ffqosLDQbdzSpUs1ceJE9ejRw4iyAQCAjzI83GRkZOjgwYPKz8+X3W5XfHy8ysrKXJuMq6ur5efnvsC0c+dOrV+/Xu+++64RJcMkEmY97/FzVD2Z6fFzAADcGR5uJCknJ0c5OTktvmez2Zq1DR48mKdSAwCAFhn+JX4AAACtiXADAABMhXADAABMhXADAABMhXADAABMhXADAABMhXADAABMxSe+5wYwKx4gCQDex8oNAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcINAAAwFcPDzcKFCxUTE6Pg4GAlJyersrLytP0PHz6s6dOnq1evXgoKCtJFF12ktWvXeqlaAADg6zoYefLVq1crNzdXJSUlSk5OVnFxsdLT07Vz505FREQ069/Y2KhrrrlGEREReuWVV9SnTx/t3btXYWFh3i8eAAD4JEPDTVFRkaZNm6bs7GxJUklJid5++22VlpZq9uzZzfqXlpbq+++/18cff6yAgABJUkxMjDdLBgAAPs6wy1KNjY2qqqpSWlraL8X4+SktLU0VFRUtjlmzZo1SUlI0ffp0RUZG6uKLL9bcuXPlcDi8VTYAAPBxhq3cHDp0SA6HQ5GRkW7tkZGR2rFjR4tjvv76
a/3f//2fJk2apLVr12r37t264447dOLECRUUFLQ4pqGhQQ0NDa7XdXV1rTcJAADgcwzfUHw2mpqaFBERocWLFyshIUEZGRm6//77VVJScsoxhYWFCg0NdR3R0dFerBgAAHibYeEmPDxc/v7+qqmpcWuvqalRVFRUi2N69eqliy66SP7+/q62IUOGyG63q7GxscUxeXl5OnLkiOvYt29f600CAAD4HMPCTWBgoBISEmS1Wl1tTU1NslqtSklJaXHMqFGjtHv3bjU1Nbnadu3apV69eikwMLDFMUFBQeratavbAQAAzMvQy1K5ublasmSJVqxYoe3bt+v2229XfX296+6pzMxM5eXlufrffvvt+v777zVjxgzt2rVLb7/9tubOnavp06cbNQUAAOBjDL0VPCMjQwcPHlR+fr7sdrvi4+NVVlbm2mRcXV0tP79f8ld0dLTWrVunmTNnavjw4erTp49mzJihe++916gpmELCrOc9fo6qJzM9fg4AACSDw40k5eTkKCcnp8X3bDZbs7aUlBR98sknHq4KAAC0VW3qbikAAIDfQrgBAACmQrgBAACmQrgBAACmQrgBAACmQrgBAACmQrgBAACmQrgBAACmQrgBAACmQrgBAACmQrgBAACmYvizpdA+VD88zOPn6Ju/xePnAAD4PlZuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqRBuAACAqfhEuFm4cKFiYmIUHBys5ORkVVZWnrLv8uXLZbFY3I7g4GAvVgsAAHyZ4eFm9erVys3NVUFBgTZu3KgRI0YoPT1dtbW1pxzTtWtXHThwwHXs3bvXixUDAABfZni4KSoq0rRp05Sdna24uDiVlJQoJCREpaWlpxxjsVgUFRXlOiIjI71YMQAA8GWGhpvGxkZVVVUpLS3N1ebn56e0tDRVVFScctyxY8fUr18/RUdH6/rrr9e2bdtO2behoUF1dXVuBwAAMC9Dw82hQ4fkcDiarbxERkbKbre3OGbw4MEqLS3VG2+8oRdffFFNTU0aOXKkvv322xb7FxYWKjQ01HVER0e3+jwAAIDvMPyy1NlKSUlRZmam4uPjNWbMGL366qvq2bOnnn322Rb75+Xl6ciRI65j3759Xq4YAAB4UwcjTx4eHi5/f3/V1NS4tdfU1CgqKuqMPiMgIECXXHKJdu/e3eL7QUFBCgoKOu9aAQBA22Doyk1gYKASEhJktVpdbU1NTbJarUpJSTmjz3A4HNqyZYt69erlqTIBAEAbYujKjSTl5uYqKytLiYmJSkpKUnFxserr65WdnS1JyszMVJ8+fVRYWChJevjhh3X55Zdr4MCBOnz4sJ588knt3btXU6dONXIaAADARxgebjIyMnTw4EHl5+fLbrcrPj5eZWVlrk3G1dXV8vP7ZYHphx9+0LRp02S329WtWzclJCTo448/VlxcnFFTAAAAPsTwcCNJOTk5ysnJafE9m83m9nr+/PmaP3++F6oCAABtUZu7WwoAAOB0CDcAAMBUCDcAAMBUCDcAAMBUCDcAAMBUCDcAAMBUCDcAAMBUCDcAAMBUCDcAAMBUfOIbin1FwqznPX6OqiczPX4OAADas3Naubnqqqt0+PDhZu11dXW66qqrzrcmAACAc3ZO4cZms6mxsbFZ+08//aSPPvrovIsCAAA4V2d1Weqf//yn689ffPGF7Ha767XD4VBZWZn69OnTetUBAACcpbMKN/Hx8bJYLLJYLC1efurYsaOefvrpVisOAADgbJ1VuPnmm2/kdDp14YUXqrKyUj179nS9FxgYqIiICPn7+7d6kQAAAGfqrMJNv379JElNTU0eKQYAAOB8nXG4WbNmjcaNG6eAgACtWbPmtH2vu+668y4MAADgXJxxuJk4caLsdrsi
IiI0ceLEU/azWCxyOBytURsAAMBZO+Nw8+tLUVyWOnfVDw/z+Dn65m/x+DkAAPBV5/wNxVarVVarVbW1tW5hx2KxaOnSpa1SHAAAwNk6p3AzZ84cPfzww0pMTFSvXr1ksVhauy4AAIBzck7hpqSkRMuXL9ctt9zS2vUAAACcl3N6/EJjY6NGjhzZ2rUAAACct3MKN1OnTtXKlStbuxYAAIDzdsaXpXJzc11/bmpq0uLFi/X+++9r+PDhCggIcOtbVFTUehUCAACchTMON5s2bXJ7HR8fL0naunWrWzubiwEAgJHOONx88MEHnqwDAACgVZzTnhsAAABfRbgBAACmQrgBAACm4hPhZuHChYqJiVFwcLCSk5NVWVl5RuNWrVoli8Vy2gd5AgCA9sXwcLN69Wrl5uaqoKBAGzdu1IgRI5Senq7a2trTjtuzZ4/uuecejR492kuVAgCAtsDwcFNUVKRp06YpOztbcXFxKikpUUhIiEpLS085xuFwaNKkSZozZ44uvPBCL1YLAAB8naHhprGxUVVVVUpLS3O1+fn5KS0tTRUVFacc9/DDDysiIkJTpkzxRpkAAKANOacHZ7aWQ4cOyeFwKDIy0q09MjJSO3bsaHHM+vXrtXTpUm3evPmMztHQ0KCGhgbX67q6unOuFwAA+D7DL0udjaNHj+qWW27RkiVLFB4efkZjCgsLFRoa6jqio6M9XCUAADCSoSs34eHh8vf3V01NjVt7TU2NoqKimvX/6quvtGfPHk2YMMHV1tTUJEnq0KGDdu7cqQEDBriNycvLc3suVl1dHQEHAAATMzTcBAYGKiEhQVar1XU7d1NTk6xWq3Jycpr1j42N1ZYtW9zaHnjgAR09elQLFixoMbQEBQUpKCjII/UDAADfY2i4kX5+2nhWVpYSExOVlJSk4uJi1dfXKzs7W5KUmZmpPn36qLCwUMHBwbr44ovdxoeFhUlSs3YAANA+GR5uMjIydPDgQeXn58tutys+Pl5lZWWuTcbV1dXy82tTW4MAAICBDA83kpSTk9PiZShJstlspx27fPny1i8IAAC0WSyJAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAU/GJcLNw4ULFxMQoODhYycnJqqysPGXfV199VYmJiQoLC1OnTp0UHx+vF154wYvVAgAAX2Z4uFm9erVyc3NVUFCgjRs3asSIEUpPT1dtbW2L/bt37677779fFRUV+uc//6ns7GxlZ2dr3bp1Xq4cAAD4IsPDTVFRkaZNm6bs7GzFxcWppKREISEhKi0tbbF/amqqfv/732vIkCEaMGCAZsyYoeHDh2v9+vVerhwAAPgiQ8NNY2OjqqqqlJaW5mrz8/NTWlqaKioqfnO80+mU1WrVzp07deWVV7bYp6GhQXV1dW4HAAAwL0PDzaFDh+RwOBQZGenWHhkZKbvdfspxR44cUefOnRUYGKjx48fr6aef1jXXXNNi38LCQoWGhrqO6OjoVp0DAADwLYZfljoXXbp00ebNm7VhwwY99thjys3Nlc1ma7FvXl6ejhw54jr27dvn3WIBAIBXdTDy5OHh4fL391dNTY1be01NjaKiok45zs/PTwMHDpQkxcfHa/v27SosLFRqamqzvkFBQQoKCmrVugEAgO8ydOUmMDBQCQkJslqtrrampiZZrValpKSc8ec0NTWpoaHBEyUCAIA2xtCVG0nKzc1VVlaWEhMTlZSUpOLiYtXX1ys7O1uSlJmZqT59+qiwsFDSz3toEhMTNWDAADU0NGjt2rV64YUXtGjRIiOn
AQAAfITh4SYjI0MHDx5Ufn6+7Ha74uPjVVZW5tpkXF1dLT+/XxaY6uvrdccdd+jbb79Vx44dFRsbqxdffFEZGRlGTQEAAPgQw8ONJOXk5CgnJ6fF9/59o/Cjjz6qRx991AtVAQCAtqhN3i0FAABwKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKoQbAABgKj4RbhYuXKiYmBgFBwcrOTlZlZWVp+y7ZMkSjR49Wt26dVO3bt2UlpZ22v4AAKB9MTzcrF69Wrm5uSooKNDGjRs1YsQIpaenq7a2tsX+NptNN998sz744ANVVFQoOjpaY8eO1f79+71cOQAA8EWGh5uioiJNmzZN2dnZiouLU0lJiUJCQlRaWtpi/5deekl33HGH4uPjFRsbq+eee05NTU2yWq1erhwAAPgiQ8NNY2OjqqqqlJaW5mrz8/NTWlqaKioqzugzjh8/rhMnTqh79+6eKhMAALQhHYw8+aFDh+RwOBQZGenWHhkZqR07dpzRZ9x7773q3bu3W0D6tYaGBjU0NLhe19XVnXvBAADA5xl+Wep8zJs3T6tWrdJrr72m4ODgFvsUFhYqNDTUdURHR3u5SgAA4E2Ghpvw8HD5+/urpqbGrb2mpkZRUVGnHfvUU09p3rx5evfddzV8+PBT9svLy9ORI0dcx759+1qldgAA4JsMDTeBgYFKSEhw2wx8cnNwSkrKKcc98cQTeuSRR1RWVqbExMTTniMoKEhdu3Z1OwAAgHkZuudGknJzc5WVlaXExEQlJSWpuLhY9fX1ys7OliRlZmaqT58+KiwslCQ9/vjjys/P18qVKxUTEyO73S5J6ty5szp37mzYPAAAgG8wPNxkZGTo4MGDys/Pl91uV3x8vMrKylybjKurq+Xn98sC06JFi9TY2Kg//OEPbp9TUFCghx56yJulAwAAH2R4uJGknJwc5eTktPiezWZze71nzx7PFwQAANqsNn23FAAAwL8j3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMh3AAAAFMxPNwsXLhQMTExCg4OVnJysiorK0/Zd9u2bfrP//xPxcTEyGKxqLi42HuFAgCANsHQcLN69Wrl5uaqoKBAGzdu1IgRI5Senq7a2toW+x8/flwXXnih5s2bp6ioKC9XCwAA2gJDw01RUZGmTZum7OxsxcXFqaSkRCEhISotLW2x/2WXXaYnn3xSN910k4KCgrxcLQAAaAsMCzeNjY2qqqpSWlraL8X4+SktLU0VFRWtdp6GhgbV1dW5HQAAwLwMCzeHDh2Sw+FQZGSkW3tkZKTsdnurnaewsFChoaGuIzo6utU+GwAA+B7DNxR7Wl5eno4cOeI69u3bZ3RJAADAgzoYdeLw8HD5+/urpqbGrb2mpqZVNwsHBQWxPwcAgHbEsJWbwMBAJSQkyGq1utqamppktVqVkpJiVFkAAKCNM2zlRpJyc3OVlZWlxMREJSUlqbi4WPX19crOzpYkZWZmqk+fPiosLJT08ybkL774wvXn/fv3a/PmzercubMGDhxo2DwAAIDvMDTcZGRk6ODBg8rPz5fdbld8fLzKyspcm4yrq6vl5/fL4tJ3332nSy65xPX6qaee0lNP
PaUxY8bIZrN5u3wAAOCDDA03kpSTk6OcnJwW3/v3wBITEyOn0+mFqgAAQFtl+rulAABA+0K4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApkK4AQAApuIT4WbhwoWKiYlRcHCwkpOTVVlZedr+f/vb3xQbG6vg4GANGzZMa9eu9VKlAADA1xkeblavXq3c3FwVFBRo48aNGjFihNLT01VbW9ti/48//lg333yzpkyZok2bNmnixImaOHGitm7d6uXKAQCALzI83BQVFWnatGnKzs5WXFycSkpKFBISotLS0hb7L1iwQP/xH/+hWbNmaciQIXrkkUd06aWX6plnnvFy5QAAwBcZGm4aGxtVVVWltLQ0V5ufn5/S0tJUUVHR4piKigq3/pKUnp5+yv4AAKB96WDkyQ8dOiSHw6HIyEi39sjISO3YsaPFMXa7vcX+dru9xf4NDQ1qaGhwvT5y5Igkqa6urllfR8OPZ1X/uTga4PD4OVqa2+kwb89h3p7DvM8M8/Yc5u05Lc37ZJvT6fzN8YaGG28oLCzUnDlzmrVHR0cbUI10sTdOUhjqjbOcFebtQczbZzBvD2LePsPoeR89elShoaf/uRgabsLDw+Xv76+amhq39pqaGkVFRbU4Jioq6qz65+XlKTc31/W6qalJ33//vXr06CGLxXKeMzg7dXV1io6O1r59+9S1a1evnttIzJt5twfMm3m3B0bO2+l06ujRo+rdu/dv9jU03AQGBiohIUFWq1UTJ06U9HP4sFqtysnJaXFMSkqKrFar7r77blfbe++9p5SUlBb7BwUFKSgoyK0tLCysNco/Z127dm1X/zGcxLzbF+bdvjDv9sWoef/Wis1Jhl+Wys3NVVZWlhITE5WUlKTi4mLV19crOztbkpSZmak+ffqosLBQkjRjxgyNGTNGf/7znzV+/HitWrVKn332mRYvXmzkNAAAgI8wPNxkZGTo4MGDys/Pl91uV3x8vMrKylybhqurq+Xn98tNXSNHjtTKlSv1wAMP6L777tOgQYP0+uuv6+KLvXIVEAAA+DjDw40k5eTknPIylM1ma9Z244036sYbb/RwVa0vKChIBQUFzS6TmR3zZt7tAfNm3u1BW5m3xXkm91QBAAC0EYZ/QzEAAEBrItwAAABTIdwAAABTIdx4wYcffqgJEyaod+/eslgsev31140uySsKCwt12WWXqUuXLoqIiNDEiRO1c+dOo8vyuEWLFmn48OGu74FISUnRO++8Y3RZXjdv3jxZLBa376Qyo4ceekgWi8XtiI2NNbosr9i/f78mT56sHj16qGPHjho2bJg+++wzo8vyqJiYmGZ/3xaLRdOnTze6NI9yOBx68MEH1b9/f3Xs2FEDBgzQI488ckaPQjCCT9wtZXb19fUaMWKEbrvtNt1www1Gl+M15eXlmj59ui677DL961//0n333aexY8fqiy++UKdOnYwuz2MuuOACzZs3T4MGDZLT6dSKFSt0/fXXa9OmTRo6dKjR5XnFhg0b9Oyzz2r48OFGl+IVQ4cO1fvvv+963aGD+f/X+sMPP2jUqFH63e9+p3feeUc9e/bUl19+qW7duhldmkdt2LBBDscvz1baunWrrrnmmjZ5B+/ZePzxx7Vo0SKtWLFCQ4cO1Weffabs7GyFhobqrrvuMrq8Zsz/X6APGDdunMaNG2d0GV5XVlbm9nr58uWKiIhQVVWVrrzySoOq8rwJEya4vX7ssce0aNEiffLJJ+0i3Bw7dkyTJk3SkiVL9Oijjxpdjld06NDhlI+AMavHH39c0dHRWrZsmautf//+Blbk
HT179nR7PW/ePA0YMEBjxowxqCLv+Pjjj3X99ddr/Pjxkn5ewfrrX/+qyspKgytrGZel4DUnn8jevXt3gyvxHofDoVWrVqm+vv6Ujwgxm+nTp2v8+PFKS0szuhSv+fLLL9W7d29deOGFmjRpkqqrq40uyePWrFmjxMRE3XjjjYqIiNAll1yiJUuWGF2WVzU2NurFF1/Ubbfd5vVnFXrbyJEjZbVatWvXLknS559/rvXr1/vsP9xZuYFXNDU16e6779aoUaPaxbdJb9myRSkpKfrpp5/UuXNnvfbaa4qLizO6LI9btWqVNm7cqA0bNhhditckJydr+fLlGjx4sA4cOKA5c+Zo9OjR2rp1q7p06WJ0eR7z9ddfa9GiRcrNzdV9992nDRs26K677lJgYKCysrKMLs8rXn/9dR0+fFi33nqr0aV43OzZs1VXV6fY2Fj5+/vL4XDoscce06RJk4wurUWEG3jF9OnTtXXrVq1fv97oUrxi8ODB2rx5s44cOaJXXnlFWVlZKi8vN3XA2bdvn2bMmKH33ntPwcHBRpfjNb/+l+vw4cOVnJysfv366eWXX9aUKVMMrMyzmpqalJiYqLlz50qSLrnkEm3dulUlJSXtJtwsXbpU48aNO6OnVLd1L7/8sl566SWtXLlSQ4cO1ebNm3X33Xerd+/ePvn3TbiBx+Xk5Oitt97Shx9+qAsuuMDocrwiMDBQAwcOlCQlJCRow4YNWrBggZ599lmDK/Ocqqoq1dbW6tJLL3W1ORwOffjhh3rmmWfU0NAgf39/Ayv0jrCwMF100UXavXu30aV4VK9evZqF9SFDhujvf/+7QRV51969e/X+++/r1VdfNboUr5g1a5Zmz56tm266SZI0bNgw7d27V4WFhYQbtC9Op1N33nmnXnvtNdlstnax2fBUmpqa1NDQYHQZHnX11Vdry5Ytbm3Z2dmKjY3Vvffe2y6CjfTzhuqvvvpKt9xyi9GleNSoUaOafbXDrl271K9fP4Mq8q5ly5YpIiLCtcHW7I4fP+72EGtJ8vf3V1NTk0EVnR7hxguOHTvm9q+4b775Rps3b1b37t3Vt29fAyvzrOnTp2vlypV644031KVLF9ntdklSaGioOnbsaHB1npOXl6dx48apb9++Onr0qFauXCmbzaZ169YZXZpHdenSpdl+qk6dOqlHjx6m3md1zz33aMKECerXr5++++47FRQUyN/fXzfffLPRpXnUzJkzNXLkSM2dO1d//OMfVVlZqcWLF2vx4sVGl+ZxTU1NWrZsmbKystrFbf/Sz3eBPvbYY+rbt6+GDh2qTZs2qaioSLfddpvRpbXMCY/74IMPnJKaHVlZWUaX5lEtzVmSc9myZUaX5lG33Xabs1+/fs7AwEBnz549nVdffbXz3XffNbosQ4wZM8Y5Y8YMo8vwqIyMDGevXr2cgYGBzj59+jgzMjKcu3fvNrosr3jzzTedF198sTMoKMgZGxvrXLx4sdElecW6deuckpw7d+40uhSvqaurc86YMcPZt29fZ3BwsPPCCy903n///c6GhgajS2sRTwUHAACmwvfcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAGjXYmJiVFxcbHQZAFoR4QaAz0hNTdXdd99tdBkA2jjCDQAAMBXCDQCfcOutt6q8vFwLFiyQxWKRxWLRnj17VF5erqSkJAUFBalXr16aPXu2/vWvf7nGpaamKicnRzk5OQoNDVV4eLgefPBBnetj85577jmFhYXJarW21tQAeBnhBoBPWLBggVJSUjRt2jQdOHBABw4cUEBAgK699lpddtll+vzzz7Vo0SItXbpUjz76qNvYFStWqEOHDqqsrNSCBQtUVFSk55577qxreOKJJzR79my9++67uvrqq1tragC8rIPRBQCAJIWGhiowMFAhISGKioqSJN1///2Kjo7WM888I4vFotjYWH333Xe69957lZ+f
Lz+/n/99Fh0drfnz58tisWjw4MHasmWL5s+fr2nTpp3x+e+991698MILKi8v19ChQz0yRwDewcoNAJ+1fft2paSkyGKxuNpGjRqlY8eO6dtvv3W1XX755W59UlJS9OWXX8rhcJzRef785z9ryZIlWr9+PcEGMAHCDYB2b/To0XI4HHr55ZeNLgVAKyDcAPAZgYGBbqstQ4YMUUVFhdvm4H/84x/q0qWLLrjgAlfbp59+6vY5n3zyiQYNGiR/f/8zOm9SUpLeeecdzZ07V0899dR5zgKA0Qg3AHxGTEyMPv30U+3Zs0eHDh3SHXfcoX379unOO+/Ujh079MYbb6igoEC5ubmu/TaSVF1drdzcXO3cuVN//etf9fTTT2vGjBlnde6RI0dq7dq1mjNnDl/qB7RxbCgG4DPuueceZWVlKS4uTj/++KO++eYbrV27VrNmzdKIESPUvXt3TZkyRQ888IDbuMzMTP34449KSkqSv7+/ZsyYof/6r/866/NfccUVevvtt3XttdfK399fd955Z2tNDYAXWZzn+mUQAOADUlNTFR8fz2oLABcuSwEAAFMh3AAwrY8++kidO3c+5QHAnLgsBcC0fvzxR+3fv/+U7w8cONCL1QDwFsINAAAwFS5LAQAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAUyHcAAAAU/n/jHHoLcQZHs8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Compare retrieval hits per top_k, split by the HyDE flag (bar height = mean of `hit`).\n",
    "# NOTE(review): assumes `hit` is a 0/1 indicator, so bars read as hit rates — confirm.\n",
    "ax = sns.barplot(x='top_k', y='hit', hue='HyDE', data=hit_stat_df, errorbar=None)\n",
    "ax.set(title='Retrieval hit by top_k and HyDE', xlabel='top_k', ylabel='hit');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60f4607e-d1ad-4059-ba9b-0fe4be9cd38c",
   "metadata": {},
   "source": [
    "# 问答全流程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e41380a8-6bac-4e13-8f5b-a8033a44ef6c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:02:46.211715Z",
     "iopub.status.busy": "2024-10-01T14:02:46.210955Z",
     "iopub.status.idle": "2024-10-01T14:02:46.224031Z",
     "shell.execute_reply": "2024-10-01T14:02:46.221735Z",
     "shell.execute_reply.started": "2024-10-01T14:02:46.211643Z"
    }
   },
   "outputs": [],
   "source": [
    "# `langchain.llms` is a deprecated re-export; import from langchain_community directly.\n",
    "from langchain_community.llms import Ollama\n",
    "\n",
    "# Generator model served by a local Ollama instance; edit these to use another model/server.\n",
    "OLLAMA_MODEL = 'qwen2:7b-instruct'\n",
    "OLLAMA_BASE_URL = 'http://localhost:11434'\n",
    "\n",
    "llm = Ollama(\n",
    "    model=OLLAMA_MODEL,\n",
    "    base_url=OLLAMA_BASE_URL\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "15927967-689d-4ec9-ab37-559e405655f4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:03:11.316472Z",
     "iopub.status.busy": "2024-10-01T14:03:11.315695Z",
     "iopub.status.idle": "2024-10-01T14:03:11.330316Z",
     "shell.execute_reply": "2024-10-01T14:03:11.327997Z",
     "shell.execute_reply.started": "2024-10-01T14:03:11.316400Z"
    }
   },
   "outputs": [],
   "source": [
    "def rag(vector_db, query, n_chunks=4, llm_model=None):\n",
    "    \"\"\"Answer `query` with retrieval-augmented generation.\n",
    "\n",
    "    Retrieves the `n_chunks` most similar chunks from `vector_db`, inserts their\n",
    "    text into the prompt as context and asks the LLM for an answer.\n",
    "\n",
    "    :param vector_db: vector store exposing `similarity_search(query, k=...)`\n",
    "    :param query: question text\n",
    "    :param n_chunks: number of chunks to retrieve\n",
    "    :param llm_model: LLM with an `invoke(prompt)` method; defaults to the notebook-level `llm`\n",
    "    :return: tuple (LLM answer, list of retrieved documents)\n",
    "    \"\"\"\n",
    "    import re\n",
    "\n",
    "    prompt_tmpl = \"\"\"\n",
    "你是一个金融分析师，擅长根据所获取的信息片段，对问题进行分析和推理。\n",
    "你的任务是根据所获取的信息片段（<<<<context>>><<<</context>>>之间的内容）回答问题。\n",
    "回答保持简洁，不必重复问题，不要添加描述性解释和与答案无关的任何内容。\n",
    "已知信息：\n",
    "<<<<context>>>\n",
    "{{knowledge}}\n",
    "<<<</context>>>\n",
    "\n",
    "问题：{{query}}\n",
    "请回答：\n",
    "\"\"\".strip()\n",
    "\n",
    "    chunks = vector_db.similarity_search(query, k=n_chunks)\n",
    "    values = {\n",
    "        'knowledge': '\\n\\n'.join(doc.page_content for doc in chunks),\n",
    "        'query': query,\n",
    "    }\n",
    "    # Fill both placeholders in a single pass: with chained str.replace, a retrieved\n",
    "    # chunk containing the literal text `{{query}}` would be substituted a second time.\n",
    "    prompt = re.sub(r'\\{\\{(knowledge|query)\\}\\}', lambda m: values[m.group(1)], prompt_tmpl)\n",
    "\n",
    "    model = llm if llm_model is None else llm_model\n",
    "    return model.invoke(prompt), chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a4c3545e-b774-4252-86a2-232a4824851e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:03:11.996555Z",
     "iopub.status.busy": "2024-10-01T14:03:11.995786Z",
     "iopub.status.idle": "2024-10-01T14:03:19.514616Z",
     "shell.execute_reply": "2024-10-01T14:03:19.514252Z",
     "shell.execute_reply.started": "2024-10-01T14:03:11.996482Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023年10月美国ISM制造业PMI指数较上个月大幅下降了2.3个百分点。\n"
     ]
    }
   ],
   "source": [
    "sample_answer, sample_chunks = rag(hyde_vector_db, '2023年10月美国ISM制造业PMI指数较上月有何变化？')\n",
    "print(sample_answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "832610e4-2a24-4770-b94c-437a114a7b5e",
   "metadata": {},
   "source": [
    "## 预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "aaa82763-b110-4122-a97f-2777d2cba74c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:03:28.004970Z",
     "iopub.status.busy": "2024-10-01T14:03:28.003996Z",
     "iopub.status.idle": "2024-10-01T14:03:28.027941Z",
     "shell.execute_reply": "2024-10-01T14:03:28.025502Z",
     "shell.execute_reply.started": "2024-10-01T14:03:28.004895Z"
    }
   },
   "outputs": [],
   "source": [
    "prediction_df = qa_df[qa_df['dataset'] == 'test'][['uuid', 'question', 'qa_type', 'answer']].rename(columns={'answer': 'ref_answer'})\n",
    "\n",
    "def predict(prediction_df, vector_db, n_chunks):\n",
    "    \"\"\"Run `rag` on every question of `prediction_df`.\n",
    "\n",
    "    :param prediction_df: frame with at least a `question` column; it is not modified\n",
    "    :param vector_db: vector store passed through to `rag`\n",
    "    :param n_chunks: number of chunks retrieved per question\n",
    "    :return: a copy of `prediction_df` with two extra columns, `gen_answer`\n",
    "        (LLM answer) and `chunks` (list of retrieved documents), in row order\n",
    "    \"\"\"\n",
    "    prediction_df = prediction_df.copy()\n",
    "\n",
    "    gen_answers = []\n",
    "    retrieved_chunks = []\n",
    "    # Collect results positionally rather than in a dict keyed by question text:\n",
    "    # with a dict, duplicate questions would all receive the last computed result.\n",
    "    for question in tqdm(prediction_df['question'], total=len(prediction_df)):\n",
    "        answer, chunks = rag(vector_db, question, n_chunks=n_chunks)\n",
    "        assert len(chunks) <= n_chunks\n",
    "        gen_answers.append(answer)\n",
    "        retrieved_chunks.append(chunks)\n",
    "\n",
    "    prediction_df['gen_answer'] = gen_answers\n",
    "    prediction_df['chunks'] = retrieved_chunks\n",
    "\n",
    "    return prediction_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8656bd24-3e85-4018-96d2-b3607ce71895",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:03:31.257894Z",
     "iopub.status.busy": "2024-10-01T14:03:31.257721Z",
     "iopub.status.idle": "2024-10-01T14:18:25.943041Z",
     "shell.execute_reply": "2024-10-01T14:18:25.942669Z",
     "shell.execute_reply.started": "2024-10-01T14:03:31.257880Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37bec2425bff412aa4c4a8435c8c4eca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Generate answers for the test split, retrieving 3 chunks per question.\n",
    "pred_df = predict(\n",
    "    prediction_df=prediction_df,\n",
    "    vector_db=hyde_vector_db,\n",
    "    n_chunks=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8dee2e3-13d1-4cf4-8bfd-fb281710ddd9",
   "metadata": {},
   "source": [
    "# 评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d7102b2a-c089-4272-8efb-ff9b32378a53",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:19:32.548586Z",
     "iopub.status.busy": "2024-10-01T14:19:32.547819Z",
     "iopub.status.idle": "2024-10-01T14:19:32.601422Z",
     "shell.execute_reply": "2024-10-01T14:19:32.600595Z",
     "shell.execute_reply.started": "2024-10-01T14:19:32.548515Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import time\n",
    "\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "judge_llm = ChatOpenAI(\n",
    "    api_key=os.environ['LLM_API_KEY'],\n",
    "    base_url=os.environ['LLM_BASE_URL'],\n",
    "    model_name='qwen2-72b-instruct',\n",
    "    temperature=0\n",
    ")\n",
    "\n",
    "def evaluate(prediction_df, sleep_seconds=1):\n",
    "    \"\"\"\n",
    "    Score predictions with the judge LLM.\n",
    "    :param prediction_df: predictions; must contain the question, reference answer and generated answer in columns question, ref_answer, gen_answer\n",
    "    :param sleep_seconds: pause after each judge call (presumably to stay under the API rate limit)\n",
    "    :return: list of raw judge-model replies, one per row, in row order\n",
    "    \"\"\"\n",
    "    prompt_tmpl = \"\"\"\n",
    "你是一个经济学博士，现在我有一系列问题，有一个助手已经对这些问题进行了回答，你需要参照参考答案，评价这个助手的回答是否正确，仅回复“是”或“否”即可，不要带其他描述性内容或无关信息。\n",
    "问题：\n",
    "<question>\n",
    "{{question}}\n",
    "</question>\n",
    "\n",
    "参考答案：\n",
    "<ref_answer>\n",
    "{{ref_answer}}\n",
    "</ref_answer>\n",
    "\n",
    "助手回答：\n",
    "<gen_answer>\n",
    "{{gen_answer}}\n",
    "</gen_answer>\n",
    "请评价：\n",
    "    \"\"\"\n",
    "    results = []\n",
    "\n",
    "    for _, row in tqdm(prediction_df.iterrows(), total=len(prediction_df)):\n",
    "        # str() guards against non-string cells (e.g. numeric reference answers, None).\n",
    "        values = {\n",
    "            'question': str(row['question']),\n",
    "            'ref_answer': str(row['ref_answer']),\n",
    "            'gen_answer': str(row['gen_answer']),\n",
    "        }\n",
    "        # Single-pass substitution: text inserted for one field is never re-scanned\n",
    "        # for the other placeholders, unlike chained str.replace calls.\n",
    "        prompt = re.sub(r'\\{\\{(question|ref_answer|gen_answer)\\}\\}', lambda m: values[m.group(1)], prompt_tmpl).strip()\n",
    "        result = judge_llm.invoke(prompt).content\n",
    "        results.append(result)\n",
    "\n",
    "        time.sleep(sleep_seconds)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "609e294a-50c2-43d2-8ffd-1ef5061bddf3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:19:33.174058Z",
     "iopub.status.busy": "2024-10-01T14:19:33.173300Z",
     "iopub.status.idle": "2024-10-01T14:22:07.013000Z",
     "shell.execute_reply": "2024-10-01T14:22:07.012553Z",
     "shell.execute_reply.started": "2024-10-01T14:19:33.173986Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2d2ac3da121f4d1fb7e46f9b2a464bc8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "raw_scores = evaluate(pred_df)\n",
    "pred_df['raw_score'] = raw_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "849e5836-f8c0-4ddc-b794-92ea27cefcc2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:22:22.662626Z",
     "iopub.status.busy": "2024-10-01T14:22:22.661868Z",
     "iopub.status.idle": "2024-10-01T14:22:22.668081Z",
     "shell.execute_reply": "2024-10-01T14:22:22.667752Z",
     "shell.execute_reply.started": "2024-10-01T14:22:22.662556Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['是', '否'], dtype=object)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_df.raw_score.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "5ffc1fde-802f-48cc-af3c-4ecc021cad07",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:22:23.362175Z",
     "iopub.status.busy": "2024-10-01T14:22:23.361621Z",
     "iopub.status.idle": "2024-10-01T14:22:23.367436Z",
     "shell.execute_reply": "2024-10-01T14:22:23.366896Z",
     "shell.execute_reply.started": "2024-10-01T14:22:23.362122Z"
    }
   },
   "outputs": [],
   "source": [
    "# The prompt asks for exactly '是' or '否'; tolerate surrounding whitespace or trailing\n",
    "# punctuation (e.g. '是。') so a correct verdict is not scored 0 by an exact-match test.\n",
    "pred_df['score'] = pred_df['raw_score'].astype(str).str.strip().str.startswith('是').astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "cb8f43a9-9919-46bd-9567-6ec981406828",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-01T14:22:24.283005Z",
     "iopub.status.busy": "2024-10-01T14:22:24.282800Z",
     "iopub.status.idle": "2024-10-01T14:22:24.286338Z",
     "shell.execute_reply": "2024-10-01T14:22:24.286000Z",
     "shell.execute_reply.started": "2024-10-01T14:22:24.282990Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.67"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy: fraction of test questions the judge marked as correct.\n",
    "pred_df['score'].agg('mean')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e0bf14f-842a-4c34-94c6-791102bb81d4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
